#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import torch
from torch import nn as nn
from torch.nn import functional as F

from msmodelslim.core import QParam, QStorage, calculate_qparam, QDType, fake_quantize, dequantize
from msmodelslim.core.QAL import QABCRegistry, QScope, QScheme
from msmodelslim.quant.ir import AutoFakeQuantLinear
from msmodelslim.utils.logging import logger_setter

from msmodelslim.quant.ir import mxfp8_per_block_sym, mxfp4_per_block_sym
from msmodelslim.quant.observer import MsMinMaxBlockObserver, MinMaxBlockObserverConfig

from msmodelslim.quant.ir.utils import reshape_to_blocks, undo_reshape_to_blocks


@QABCRegistry.multi_register(
    dispatch_key=[
        (mxfp8_per_block_sym, mxfp4_per_block_sym),
    ],
    abc_type=AutoFakeQuantLinear
)
@logger_setter('msmodelslim.quant.ir.w4a8_mx_dynamic_per_block')
class W4A8MXDynamicPerBlockFakeQuantLinear(AutoFakeQuantLinear):

    def __init__(
            self,
            x_q_param: QParam,
            w_q_param: QParam,
            w_q: QStorage,
            bias: torch.Tensor
    ):
        super().__init__()
        self.w_scheme = w_q_param.scheme
        self.w_mx_finfo = w_q_param.scheme.dtype.mx_finfo
        self.w_axes = w_q_param.ext.get("axes")

        self.x_scheme = x_q_param.scheme
        self.x_mx_finfo = x_q_param.scheme.dtype.mx_finfo
        self.x_axes = x_q_param.ext.get("axes")

        self.weight_scale = nn.Parameter(w_q_param.ext.get("scale"), requires_grad=False)
        self.weight_offset = nn.Parameter(w_q_param.ext.get("offset"), requires_grad=False)
        self.weight = nn.Parameter(w_q.value, requires_grad=False)
        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None

        minmax_config = MinMaxBlockObserverConfig(axes=self.x_axes)
        self.x_minmax_block_observer = MsMinMaxBlockObserver(minmax_config)

    def __repr__(self) -> str:
        return f"W4A8MXDynamicPerBlockFakeQuantLinear(symmetric={self.w_scheme.symmetric})"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        axes = self.x_axes
        axes = [axes] if isinstance(axes, int) else axes
        axes = [a + x.ndim if a < 0 else a for a in axes]
        x, axes_, orig_shape, padded_shape = reshape_to_blocks(x, axes, self.x_mx_finfo.block_size)
        shared_exp_axes = [x + 1 for x in axes_] if self.x_mx_finfo.block_size > 0 else axes_
        self.x_minmax_block_observer.update(x, shared_exp_axes=shared_exp_axes)
        x_min_val, x_max_val = self.x_minmax_block_observer.get_min_max()
        x_q_param = calculate_qparam(
            x_min_val, x_max_val,
            q_dtype=self.x_scheme.dtype,
            q_scope=self.x_scheme.scope,
            symmetric=self.x_scheme.symmetric
        )
        x_q_dq = fake_quantize(QStorage(QDType.FLOAT, x), x_q_param)
        x_q_dq.value = undo_reshape_to_blocks(x_q_dq.value, padded_shape, orig_shape, axes)

        w_q_storage = QStorage(dtype=self.w_scheme.dtype, value=self.weight)
        axes = self.w_axes
        axes = [axes] if isinstance(axes, int) else axes
        axes = [a + w_q_storage.value.ndim if a < 0 else a for a in axes]
        w_q_storage.value, _, w_orig_shape, w_padded_shape = reshape_to_blocks(
            w_q_storage.value, axes, self.w_mx_finfo.block_size
        )
        w_q_param = QParam(scheme=QScheme(scope=QScope.PER_BLOCK, dtype=QDType.MXFP4, symmetric=True),
                           ext={"scale": self.weight_scale.data})
        weight_q_dq = dequantize(w_q_storage, w_q_param)
        weight_q_dq.value = undo_reshape_to_blocks(weight_q_dq.value, w_padded_shape, w_orig_shape, axes)

        return F.linear(x_q_dq.value, weight_q_dq.value, self.bias)
